{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyMC4 developer guide\n",
    "\n",
    "PyMC4 is based on TensorFlow Probability. This provides many benefits including a rich library of probability distributions and inference algorithms. In order to use these inference algorithms, we must provide a tensor-in-tensor-out logp function. The input tensors are placeholder tensors on which a value for a random variable (RV) can be set, and the output is a tensor of the logp of the model evaluated of these inputs. During sampling, for example, the sampler would propose different input values by updating the input tensors and compute the corresponding model logp and gradients of the logp (which is provided via autodiff of the output tensor from the TensorFlow library).\n",
    "\n",
    "This provides a challenge for PyMC4 because in order to know which input tensors to construct, we need to evaluate the users' model. For example, let's consider the (non-sensical) model:\n",
    "\n",
    "```python\n",
    "x = pm4.Normal(0, 1)\n",
    "y = pm4.Normal(x**2, 1)\n",
    "```\n",
    "\n",
    "This model should familiar to anyone who knows `PyMC3` or `STAN`. However, it hides some subtle complexity. We need to keep track of two things simultaneously here. `x` and `y` as RVs as well as the log probabilities  of `x` and `y` evaluated over some input values. That is why in TensorFlow probability you need to write your model twice: once for its RVs, and once for its logps.\n",
    "\n",
    "As `PyMC4` is supposed to focus on UX we want to hide this complexity from users. This, however, creates a chicken-and-egg type of problem: in order to create the tensor-in-tensor-out log probability function, we need to replace the RVs in the user's model with the input-tensors. But without having run the model, we don't know which input-tensors we need to create in the first place.\n",
    "\n",
    "The approach we took is to have the user define a decorated function which we call a **model template**. Template, because it's not the model itself, but rather, the instruction of how to build a model. This template function gets called **twice** under different contexts which alter how an RV gets converted to a tensor:\n",
    "1. Under `ForwardContext`: This collects the RV names, the RVs get converted to sample tensors which allow forward sampling.\n",
    "2. Under `InferenceContext`: Now when an RV gets converted to a tensor, it gets replaced with the input-tensor from the log_prob function. When we compute the log_prob for an individual RV, we evaluate it over that tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymc4 as pm\n",
    "import tensorflow as tf\n",
    "from scipy import stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = pm.Normal(mu=0, sigma=1, name='x')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far this is still a `PyMC4` RV object:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pymc4.random_variables.continuous.Normal at 0x1a28f31080>"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When used in conjuction with TF, `.as_tensor()` gets called. By default, things are evaluated under the `ForwardContext`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=3445, shape=(), dtype=float32, numpy=0.6061011>"
      ]
     },
     "execution_count": 117,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.as_tensor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is done automatically by TF by registering a `tensor_conversion_function` (in `__init__.py`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=3449, shape=(), dtype=float32, numpy=0.6061011>"
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "1 * x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6061011"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Value sampled from a normal distribution\n",
    "sample = x.as_tensor().numpy()\n",
    "sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we require the logp, we call `.log_prob()` which evaluates the distribution *over itself*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=3512, shape=(), dtype=float32, numpy=-1.1026177>"
      ]
     },
     "execution_count": 123,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.log_prob()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1.1026178022947524"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stats.norm().logpdf(sample)"
   ]
  },
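  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of the two results above (a worked calculation, assuming only the standard normal log-density implied by `mu=0, sigma=1`), the sampled value $x = 0.6061011$ gives:\n",
    "\n",
    "$$\n",
    "\\log \\mathcal{N}(x \\mid 0, 1) = -\\tfrac{1}{2} \\log(2\\pi) - \\tfrac{x^2}{2} \\approx -0.918939 - 0.183679 = -1.102618\n",
    "$$\n",
    "\n",
    "This agrees with both `x.log_prob()` and the SciPy value, up to float32 rounding."
   ]
  },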
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the context of a model, we can see what is going on more clearly by doing things manually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pm.model\n",
    "def model_template():\n",
    "    print('model is called')\n",
    "    x = pm.Normal(mu=0, sigma=1, name='x')\n",
    "    print('x is type', x)\n",
    "    print('x viewed as a tensor is', x.as_tensor())\n",
    "    print('x**2 is', x**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model is called\n",
      "x is type <pymc4.random_variables.continuous.Normal object at 0x1a28f4c7f0>\n",
      "x viewed as a tensor is tf.Tensor(-0.68766874, shape=(), dtype=float32)\n",
      "x**2 is tf.Tensor(0.4728883, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "model = model_template.configure()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<pymc4.random_variables.continuous.Normal at 0x1a28f4c7f0>]"
      ]
     },
     "execution_count": 127,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# So far, only the forward context is created with the registered RVs.\n",
    "model._forward_context.vars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(-0.68766874, shape=(), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "with model._forward_context:\n",
    "    print(model._forward_context.vars[0].as_tensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create our tensor-in tensor-out logp function\n",
    "logp_func = model.make_log_prob_function()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model is called\n",
      "x is type <pymc4.random_variables.continuous.Normal object at 0x1a28f31f98>\n",
      "x viewed as a tensor is tf.Tensor([2. 2. 2. 2. 2. 2. 2. 2. 2. 2.], shape=(10,), dtype=float32)\n",
      "x**2 is tf.Tensor([4. 4. 4. 4. 4. 4. 4. 4. 4. 4.], shape=(10,), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "x = tf.ones((10,)) * 2\n",
    "logp_tensor = logp_func(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Something important has happened the second time the model template was called. Instead of the `RV` when converted to a tensor becoming a tensor-value sampled from a normal distribution as the first time, it now converts to the input-tensor. \n",
    "\n",
    "In a second step, we loop over the created RVs, compute their log probability tensors, and sum them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=3695, shape=(), dtype=float32, numpy=-29.189386>"
      ]
     },
     "execution_count": 132,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logp_tensor"
   ]
  },
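  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hand check of this value (a worked calculation, assuming only the standard normal log-density implied by `mu=0, sigma=1`): each of the 10 input entries equals 2, so the summed logp should be\n",
    "\n",
    "$$\n",
    "\\sum_{i=1}^{10} \\log \\mathcal{N}(2 \\mid 0, 1) = 10 \\left( -\\tfrac{1}{2} \\log(2\\pi) - \\tfrac{2^2}{2} \\right) \\approx 10 \\times (-0.918939 - 2) \\approx -29.1894\n",
    "$$\n",
    "\n",
    "This agrees with the `-29.189386` shown in the output above, up to float32 rounding."
   ]
  },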
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This happens inside of `pymc4.Model.make_log_prob_function()`:\n",
    "\n",
    "```python\n",
    "def make_log_prob_function(self):\n",
    "    \"\"\"Return the log probability of the model.\"\"\"\n",
    "\n",
    "    def log_prob(*args):\n",
    "        # Create the new context with the RVs\n",
    "        # collected from the first pass where\n",
    "        # it was stored in `self._forward_context`,\n",
    "        # and link them to the input tensors provided \n",
    "        # by args.\n",
    "        context = contexts.InferenceContext(\n",
    "            args, \n",
    "            expected_vars=self._forward_context.vars)\n",
    "        with context:\n",
    "            # Call the model-template again, this time\n",
    "            # RVs will evaluate to the input-tensors\n",
    "            self._evaluate()\n",
    "            # `var.log_prob()` will evaluate the\n",
    "            # distribution's logp over the input-tensor\n",
    "            return sum(tf.reduce_sum(var.log_prob()) for var in context.vars)\n",
    "\n",
    "    # Return our tensor-in-tensor-out function\n",
    "    return log_prob\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
